{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 9.5 — Text Generation With GPT-2 And (only) PyTorch\n",
    "\n",
    "While I'm mostly happy with how the book turned out, bar some silly errors that should not have made it to print and needing about another six months to do it properly (although work would have precluded that, so…anyway), I was a little disappointed with how I handled text generation. It worked, that's for sure, but it was little more than 'run this program on this text, then run this script to transform the TensorFlow model into a PyTorch-compatible format, and run _this_ script to generate output'. And then, to top it all off, about a week after the book went to print, the repo that housed most of the code underwent a major change from `pytorch-pretrained-BERT` to its eventual name of `transformers`. A bit of a pain.\n",
    "\n",
    "As a way of making that up to people, welcome to Chapter 9.5 - A Half-Chapter in Two Parts. In this part, we'll take another look at text generation, but this time, we won't leave PyTorch. Promise. In Part Two (or is that Chapter 9.75?), we'll have a bit of a final look back at images. The common theme between both parts will be self-supervision and domain modelling. I don't have an ETA for Part Two yet, but it'll come, promise.\n",
    "\n",
    "If you're looking for a refresher on the Transformer architecture, then there's some in Chapter 9 of my book, but more usefully, you could go here to read [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/), and here for [The Illustrated GPT-2](http://jalammar.github.io/illustrated-gpt2/).\n",
    "\n",
    "## Adding New Generation Tricks To GPT-2\n",
    " \n",
    "Right, so if you remember in the book, we went on a jolly side-jaunt with P.G. Wodehouse. And that was all very fine and whimsical, but maybe we want something that shows off the capabilities of GPT-2 a little better, even if it's really just doing most of the same thing under the covers.\n",
    "\n",
    "Instead of Jeeves and Wooster, we're going to generate tweets. And we're going to take things a step further by adding a new \"control code\" to our fine-tuned GPT-2 model, so we can instruct GPT-2 that we specifically want to generate a new tweet. If we don't add the control code, then we should just get a (mostly) standard GPT-2 output. And we can use this technique to add _multiple_ control codes, so if you have different sets of synthetic data that you wish to generate, you can use those codes to determine which type to create.\n",
    " \n",
    "And first…let's go back to the standard thing we always do. \n",
    "\n",
    "_\"Gee Brain, what are we going to do tonight?\"_\n",
    "_\"The same thing we do every night Pinky. Write a new custom dataset and take over the world!\"_\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using PyTorch 1.4\n",
    "\n",
    "import pyarrow.parquet as pq\n",
    "import pandas as pd\n",
    "import torch\n",
    "import os\n",
    "import csv\n",
    "\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm, trange\n",
    "import torch.nn.functional as F\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Datasets\n",
    " \n",
    "Don't worry though, we won't be doing anything too crazy with this `Dataset`. \n",
    "\n",
    "Much. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ParquetDataset(Dataset):\n",
    "    def __init__(self, path, cols, truncate=False, gpt2_type=\"gpt2\", max_length=768):\n",
    "\n",
    "        # Grab our pandas dataframe, only reading in the columns we're interested in,\n",
    "        # append our magic tokens (<#col_name#> for the particular column, and <|endoftext|>\n",
    "        # used by GPT-2 as a text separator), then concatenate them into one giant column for\n",
    "        # our dataset\n",
    "\n",
    "        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)\n",
    "        \n",
    "        self.df = pq.read_table(path, columns=cols).to_pandas().dropna()\n",
    "        for col in cols:\n",
    "            self.df[col] = self.df[col].apply(lambda x: torch.tensor(self.tokenizer.encode(f\"<#{col}#>{x[:max_length]}<|endoftext|>\")))\n",
    "        self.df = pd.concat(map(self.df.get, cols)).reset_index(drop=True)\n",
    "        if truncate:\n",
    "            self.df = self.df.truncate(after=150)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.df.count()\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        return self.df.iloc[item]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CSVTwitter(Dataset):\n",
    "    \n",
    "    def __init__(self, control_code, truncate=False, gpt2_type=\"gpt2\", max_length=768):\n",
    "\n",
    "        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)\n",
    "        self.tweets = []\n",
    "\n",
    "        # This uses the same CSV of Sentiment140 that we created in Chapter 5\n",
    "        \n",
    "        with open('train-processed.csv', newline='') as csvfile:\n",
    "            tweet_csv = csv.reader(csvfile)\n",
    "            for row in tweet_csv:\n",
    "                self.tweets.append(torch.tensor(\n",
    "                    self.tokenizer.encode(f\"<|{control_code}|>{row[5][:max_length]}<|endoftext|>\")\n",
    "                ))\n",
    "                \n",
    "        if truncate:\n",
    "            self.tweets = self.tweets[:20000]\n",
    "        self.tweet_count = len(self.tweets)\n",
    "        \n",
    "    def __len__(self):\n",
    "        return self.tweet_count\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        return self.tweets[item]"
   ]
  },
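  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before anything reaches the tokenizer, each row gets wrapped in our delimiters. Here's a quick sanity check of that template (the sample text is just an invented placeholder):\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "control_code = \"tweet\"\n",
    "text = \"just setting up my twttr\"\n",
    "\n",
    "# Mirror the f-string used inside CSVTwitter: control code up front,\n",
    "# GPT-2's own end-of-text marker at the back\n",
    "wrapped = f\"<|{control_code}|>{text[:140]}<|endoftext|>\"\n",
    "print(wrapped)"
   ]
  },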
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Firstly, you might wonder why we're chopping our strings at 768 characters. We're going to be using the smallest GPT-2 model in this chapter, and 768 is a comfortable cap that keeps our encoded sequences inside the model's context window of 1024 _tokens_ (note that the limit is counted in tokens after encoding, not characters, and applies to all the pre-trained sizes; the 768/1024/1280/1600 figures you may have seen for `gpt2`/`gpt2-medium`/`gpt2-large`/`gpt2-xl` are their hidden dimensionalities, not their sequence limits). Of course, because this dataset is only tweets, we're never going to bump up against the limit, but I thought I'd include it so you know to be aware of it. \n",
    " \n",
    "You'll also see that we're injecting our `<|tweet|>` control code at the start of each entry, and the `<|endoftext|>` code at the end - this is actually a code that GPT-2 has already learnt during its initial training to signify the end of a piece of text. It'll become useful later on in training when we pack our training tensors.\n",
    "\n",
    "The last part of the dataset is _encoding_. This is similar to the encoding of text that we did back in Chapter 5, but with a small twist. Instead of a simple mapping of all words to a new dictionary, we are using a _byte pair encoding tokenizer_. This works in a different way to what we have seen before as it builds a dictionary by keeping track of common pairs of bytes and replaces them with a byte that is not present in the encoding. \n",
    "\n",
    "For example, take the nonsense string:\n",
    "\t\n",
    "\taabaabdeaa\n",
    "\t\t\n",
    "The first pass of the byte pair encoder would replace our `aa` strings:\n",
    "\n",
    "\tAbAbdeA\n",
    "\tA = aa\n",
    "\n",
    "But note that we now have new byte pairs and so we can replace again:\n",
    "\n",
    "\tBBdeA\n",
    "\tA = aa\n",
    "\tB = Ab\n",
    "\n",
    "When building up a vocabulary from data, the byte pair encoding used in language models these days tends to work in the opposite direction: it starts out with a set of characters in that language and, through passes over the data, builds up _subwords_ by finding the pairs present in the dataset, merging them into larger units, and so on. In this way, the tokenizer learns a vocabulary directly from the dataset itself and not from any manual input from an external source (like us).\n",
    "\n",
    "Happily, we can use the BPE tokenizer that has already been trained on the dataset of GPT-2 and not have to worry about training it ourselves here (though if you're looking to train on a new language, [Huggingface's tutorial on learning Esperanto](https://huggingface.co/blog/how-to-train) will tell you everything you need to get started). We create a pre-trained version using `GPT2Tokenizer.from_pretrained(gpt2_type)`, which will download the appropriate files for the version of GPT-2 we're working with. We then encode the dataset and create tensors, returning a particular tensor within `__getitem__()` as normal.\n",
    "\n",
    "In addition to the CSV-based `Dataset`, I've also included a different implementation that uses PyArrow to load in named columns from a parquet file. I just had a bunch of parquet-based datasets lying around so it was useful to make a class that could handle them as well.\n",
    " \n",
    "We'll build a `DataLoader` in our usual way:\n",
    "\n",
    "    DataLoader(dataset, batch_size=1, shuffle=True)  \n",
    "    \n",
    "(the reason for `batch_size` being 1 is something we'll come back to later)"
   ]
  },
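  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the byte-pair idea concrete, here's a toy sketch of the replacement passes from the example above (plain Python; the `bpe_pass` helper is purely illustrative and not part of any real tokenizer):\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "def bpe_pass(text, replacement):\n",
    "    # Replace the most frequent adjacent pair with a new symbol\n",
    "    pairs = Counter(text[i:i + 2] for i in range(len(text) - 1))\n",
    "    best = max(pairs, key=pairs.get)\n",
    "    return text.replace(best, replacement), best\n",
    "\n",
    "text, merged = bpe_pass(\"aabaabdeaa\", \"A\")\n",
    "print(text, merged)  # AbAbdeA aa\n",
    "text, merged = bpe_pass(text, \"B\")\n",
    "print(text, merged)  # BBdeA Ab"
   ]
  },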
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training\n",
    " \n",
    " Okay, so how do we train this thing? Well, it turns out that it's actually a lot simpler than you'd think. We already have a pre-trained model, so we're just doing some fine-tuning (we won't freeze layers here, but you can certainly experiment with it). But…don't we need labels? \n",
    " \n",
    "Training GPT-2 involves passing our input text into the transformer model…and training the model to get the same text back as output. In this way, the model learns something of how text is structured, and eventually builds up a _language model_ that can be used for generating further text. So our labels are the input text! \n",
    "\n",
    "To get the model to produce anything resembling English or whatever language you're training it on requires a gargantuan amount of text (OpenAI trained GPT-2 on 8 million webpages). But as we're using a pre-trained model, all that hard work has been done for us, so we can get away with a much smaller dataset. We can create a pre-trained GPT-2 transformer with one line of code:\n",
    "\n",
    "\tmodel = GPT2LMHeadModel.from_pretrained(gpt2_type)\n",
    "\n",
    "As for our training loop, given that our labels are our input, all we're really doing is:\n",
    "\n",
    "\toutputs = model(input)\n",
    "\tloss = loss_function(outputs, input)\n",
    "\tloss.backward()\n",
    "\toptimizer.step()\n",
    "\n",
    "But there's a slight catch. You remember that GPT-2 is big, right? Very big. It's quite possible that you can't fit all the parameters and all the gradient updates inside your GPU. I know I can't, and I have a 1080Ti. There are various approaches we can use to get around this problem, like distributed training, or gradient checkpointing (covered in Chapter 7).\n",
    "\n",
    "However, there's a simpler option we can use. What we're going to do is _accumulate_ our gradients for a number of batches and then apply the update every _x_ batches instead of every batch. (If you want each update to have a consistent magnitude, you can also divide the loss by the number of accumulated batches before calling `backward()`.)\n",
    "\n",
    "We're almost at the point of having the training loop sorted. But what's that, Columbo?\n",
    "\n",
    "You may have looked at the links to the illustrated Transformer articles and discovered that GPT-2 will 'see' all of its input at once. And we're sending in encoded tensors of 140-character strings. That's leaving a lot of our input set to…basically zero. Is that going to be great for training? Probably not, as we're not going to get a lot of information flowing forwards and backwards through our network. Enter…`pack_tensor()`!\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pack_tensor(new_tensor, packed_tensor, max_seq_len):\n",
    "    if packed_tensor is None:\n",
    "        return new_tensor, True, None\n",
    "    if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:\n",
    "        return packed_tensor, False, new_tensor\n",
    "    else:\n",
    "        packed_tensor = torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1)\n",
    "        return packed_tensor, True, None"
   ]
  },
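  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check with toy tensors can make the packing behaviour easier to see (the token values here are made up, and `pack_tensor` is repeated so the cell runs on its own):\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# pack_tensor as defined above, repeated here so this cell stands alone\n",
    "def pack_tensor(new_tensor, packed_tensor, max_seq_len):\n",
    "    if packed_tensor is None:\n",
    "        return new_tensor, True, None\n",
    "    if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:\n",
    "        return packed_tensor, False, new_tensor\n",
    "    return torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1), True, None\n",
    "\n",
    "a = torch.tensor([[1, 2, 3]])\n",
    "b = torch.tensor([[4, 5, 6]])\n",
    "\n",
    "packed, fits, leftover = pack_tensor(a, None, 6)\n",
    "packed, fits, leftover = pack_tensor(b, packed, 6)\n",
    "print(packed, fits, leftover)  # b first, then a minus its first token\n",
    "\n",
    "c = torch.tensor([[7, 8]])\n",
    "packed, fits, leftover = pack_tensor(c, packed, 6)\n",
    "print(fits, leftover)  # no room left: c comes back as the remainder"
   ]
  },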
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a very simple method that just tries to fit as many pieces of text into an input tensor as possible. This is why we created the `DataLoader` with a `batch_size` of 1: in our training loop, we simply loop over the data until we've stuffed a tensor, and then push it through our model. Of course, this breaks the relationship between the batches that come from the `Dataset` and what we send to the model for training, so we add `accumulating_batch_count` as a counter to work out when we need to apply our accumulated gradients.\n",
    "\n",
    "You'll also notice in the `train()` code below that instead of our normal pattern of:\n",
    "\n",
    "\toutputs = model(input)\n",
    "\tloss = loss_function(outputs, input)\n",
    "\n",
    "We're actually doing:\n",
    "\n",
    "\toutputs = model(input, labels=input)\n",
    "\tloss = outputs[0]\n",
    "\n",
    "There's nothing too nefarious going on here; the GPT-2 model simply has code inside it that calculates the loss to make things easier. It's just a simple `CrossEntropyLoss`, as we've seen in previous chapters.\n",
    "\n",
    "Our optimizer and learning rate schedule also come from the `transformers` library: we're using the AdamW ([Adam + Weight Decay](https://www.fast.ai/2018/07/02/adam-weight-decay/)) optimizer with a warmup and linear decay (you can see alternatives at [Huggingface's docs page](https://huggingface.co/transformers/main_classes/optimizer_schedules.html)). We also include the ability to save a set of weights at the end of each epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(\n",
    "    dataset,\n",
    "    model,\n",
    "    tokenizer,\n",
    "    batch_size=16,\n",
    "    epochs=4,\n",
    "    lr=2e-5,\n",
    "    max_seq_len=400,\n",
    "    warmup_steps=5000,\n",
    "    gpt2_type=\"gpt2\",\n",
    "    device=\"cuda\",\n",
    "    output_dir=\".\",\n",
    "    output_prefix=\"wreckgar\",\n",
    "    test_mode=False,\n",
    "    save_model_on_epoch=False,\n",
    "):\n",
    "\n",
    "\n",
    "    model = model.to(device)\n",
    "    model.train()\n",
    "\n",
    "    optimizer = AdamW(model.parameters(), lr=lr)\n",
    "    scheduler = get_linear_schedule_with_warmup(\n",
    "        optimizer, num_warmup_steps=warmup_steps, num_training_steps=-1\n",
    "    )\n",
    "\n",
    "    train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "\n",
    "    accumulating_batch_count = 0\n",
    "    input_tensor = None\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "\n",
    "        print(f\"Training epoch {epoch}\")\n",
    "        for idx, entry in tqdm(enumerate(train_dataloader)):\n",
    "            (input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, max_seq_len)\n",
    "\n",
    "            if carry_on and idx != len(train_dataloader) - 1:\n",
    "                continue\n",
    "\n",
    "            input_tensor = input_tensor.to(device)\n",
    "            outputs = model(input_tensor, labels=input_tensor)\n",
    "            loss = outputs[0]\n",
    "            loss.backward()\n",
    "\n",
    "            if (accumulating_batch_count % batch_size) == 0:\n",
    "                optimizer.step()\n",
    "                scheduler.step()\n",
    "                optimizer.zero_grad()\n",
    "                model.zero_grad()\n",
    "\n",
    "            accumulating_batch_count += 1\n",
    "            input_tensor = remainder  # carry over anything that didn't fit in this pack\n",
    "        if save_model_on_epoch:\n",
    "            torch.save(\n",
    "                model.state_dict(),\n",
    "                os.path.join(output_dir, f\"{output_prefix}-{epoch}.pt\"),\n",
    "            )\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = CSVTwitter(\"tweet\", truncate=True, gpt2_type=\"gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt2_type = \"gpt2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model = train(\n",
    "    dataset,\n",
    "    GPT2LMHeadModel.from_pretrained(gpt2_type),\n",
    "    GPT2Tokenizer.from_pretrained(gpt2_type),\n",
    "    batch_size=16,\n",
    "    epochs=1,\n",
    "    lr=3e-5,\n",
    "    max_seq_len=140,\n",
    "    warmup_steps=5000,\n",
    "    gpt2_type=gpt2_type,\n",
    "    device=\"cuda\",\n",
    "    output_dir=\"trained_models\",\n",
    "    output_prefix=\"twitter\",\n",
    "    save_model_on_epoch=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generating Text\n",
    " \n",
    "For generating text from our fine-tuned model, there are multiple approaches that we could use, including _beam search_, *top_k filtering*, and the one we're going to use, _nucleus sampling_ (or *top_p filtering*). We take our input, in this case our new control code `<|tweet|>`, and feed that into the model to generate a new sequence. But all we care about is the next word, and in particular, the probabilities of all the possible words that the model predicts should appear there. \n",
    "\n",
    "Of course, lots of words that the model may predict will not make sense, and that's where _nucleus sampling_ (or *top_k*, or any other approach) comes in. We sort the predicted probabilities in descending order and sum them *until* the running total (the cumulative distribution function) exceeds an adjustable hyperparameter, `p`, which is normally set between 0.7 and 0.9. There's another parameter, `temperature`, which scales the logits before the softmax turns them into the probabilities that get summed into the CDF. \n",
    "\n",
    "Once the CDF is formed, we eliminate everything that falls outside our `p` by setting its logits to `-Infinity`. We're not messing around here. Note that as we're summing the highest-probability selections first, if there are a few high-probability choices, it's possible they'll be the only ones present. And that makes sense if you think about sentences like:\n",
    "\t\n",
    "\tThe dog lifted up its ____\n",
    "\t\n",
    "Possible options here could include `paw, tail, tongue`. You'd expect `paw` or `tail` much more than `tongue`. In this way, our sampling feels more natural, while still providing the possibility for surprise when the probabilities are more spread out.\n",
    "\n",
    "Most of the code here is taken from Huggingface's [`run_generation.py` script](https://github.com/huggingface/transformers/blob/master/examples/run_generation.py). \n",
    "\n",
    "Once we have our next word, we loop back around to the start, but this time we feed in the sentence with the new word added and choose the following word in the same way. We continue until we either reach `entry_length` or if the model generates a `<|endoftext|>` marker. And then it's back to the outer loop to generate our next sentence until we've generated the requested number of sentences.\n"
   ]
  },
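  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's the filtering step in miniature, using made-up probabilities for the dog sentence above (plain Python, no model involved):\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "probs = {\"paw\": 0.55, \"tail\": 0.30, \"tongue\": 0.10, \"spoon\": 0.05}\n",
    "\n",
    "def nucleus(probs, p):\n",
    "    # Keep the highest-probability words until the running total first\n",
    "    # exceeds p; everything else would get its logit set to -Infinity\n",
    "    kept, total = [], 0.0\n",
    "    for word, prob in sorted(probs.items(), key=lambda kv: -kv[1]):\n",
    "        kept.append(word)\n",
    "        total += prob\n",
    "        if total > p:\n",
    "            break\n",
    "    return kept\n",
    "\n",
    "print(nucleus(probs, 0.8))  # ['paw', 'tail']"
   ]
  },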
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    prompt,\n",
    "    entry_count=10,\n",
    "    entry_length=100,\n",
    "    top_p=0.8,\n",
    "    temperature=1.,\n",
    "):\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    generated_num = 0\n",
    "    generated_list = []\n",
    "\n",
    "    filter_value = -float(\"Inf\")\n",
    "\n",
    "    with torch.no_grad():\n",
    "\n",
    "        for entry_idx in trange(entry_count):\n",
    "\n",
    "            entry_finished = False\n",
    "\n",
    "            generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)\n",
    "\n",
    "            # Using top-p (nucleus sampling): https://github.com/huggingface/transformers/blob/master/examples/run_generation.py\n",
    "\n",
    "            for i in range(entry_length):\n",
    "                outputs = model(generated, labels=generated)\n",
    "                loss, logits = outputs[:2]\n",
    "                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)\n",
    "\n",
    "                sorted_logits, sorted_indices = torch.sort(logits, descending=True)\n",
    "                cumulative_probs = torch.cumsum(\n",
    "                    F.softmax(sorted_logits, dim=-1), dim=-1\n",
    "                )\n",
    "\n",
    "                sorted_indices_to_remove = cumulative_probs > top_p\n",
    "                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[\n",
    "                    ..., :-1\n",
    "                ].clone()\n",
    "                sorted_indices_to_remove[..., 0] = 0\n",
    "\n",
    "                indices_to_remove = sorted_indices[sorted_indices_to_remove]\n",
    "                logits[:, indices_to_remove] = filter_value\n",
    "\n",
    "                next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)\n",
    "                generated = torch.cat((generated, next_token), dim=1)\n",
    "\n",
    "                if next_token in tokenizer.encode(\"<|endoftext|>\"):\n",
    "                    entry_finished = True\n",
    "\n",
    "                if entry_finished:\n",
    "\n",
    "                    generated_num = generated_num + 1\n",
    "\n",
    "                    output_list = list(generated.squeeze().numpy())\n",
    "                    output_text = tokenizer.decode(output_list)\n",
    "\n",
    "                    generated_list.append(output_text)\n",
    "                    break\n",
    "            \n",
    "            if not entry_finished:\n",
    "                output_list = list(generated.squeeze().numpy())\n",
    "                output_text = f\"{tokenizer.decode(output_list)}<|endoftext|>\" \n",
    "                generated_list.append(output_text)\n",
    "                \n",
    "    return generated_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_tweets = generate(model.to('cpu'), GPT2Tokenizer.from_pretrained(gpt2_type), \"<|tweet|>\", entry_count=10)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example Output\n",
    "\n",
    "And here's some output of calling `generate` on our trained model.\n",
    "\n",
    "\t\"<|tweet|>Casa the fifth Monday afternoons in the summer. Stay for one more - you'll be much better at finding a workplace than you would at the \toffice.\\n\\nThe Hours\\n\\n14:00 - 15:00, Hot and Cold\\n\\n18:00 - 19:00, Cafe Oktoberfest\\n\\n19:00 - 21:00, More Information<|endoftext|>\",\n",
    " \t'<|tweet|>Tweet what you like.<|endoftext|>',\n",
    "\t'<|tweet|>Sigh. Hope to see ya in there.<|endoftext|>',\n",
    " \t'<|tweet|> | The Walking Dead ends, '10 hours after everybody gets killed! I'm sick of zombies. pic.twitter.com/tsxhXdGLuGx.<|endoftext|>'\n",
    "\n",
    "### Further Techniques & Reading\n",
    "\n",
    "[Huggingface](https://huggingface.co/)\n",
    "\n",
    "[Better Language Models and Their Implications (GPT-2)](https://openai.com/blog/better-language-models/)\n",
    "\n",
    "[Applying BERT-based models in Search](https://www.blog.google/products/search/search-language-understanding-bert/)\n",
    "\n",
    "[How To Sample From Language Models](https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
